import numpy as np
from nlu_model.util.pkl_impl import save_pkl, load_pkl
from gensim.models.word2vec import Word2Vec
from gensim.models import KeyedVectors

class Word2vector(object):
    """Train or load word2vec embeddings and map words/sentences to indices.

    Index 0 is always reserved for the padding token "<PAD>" (zero vector)
    and the last index for the out-of-vocabulary token "<UNK>", whose
    vector is the mean of all in-vocabulary vectors.
    """

    def __init__(self):
        # word -> integer index (includes "<PAD>" and "<UNK>")
        self.word2idx_dic = {}
        # row i is the embedding vector of the word with index i
        self.embedding_weights = []

    def train(self,
              train_data,                          # tokenised training sentences
              N_DIM = 300,                         # embedding dimensionality
              MIN_COUNT = 5,                       # min word frequency to enter the vocabulary
              w2v_EPOCH = 15,                      # number of word2vec training epochs
              MAXLEN = 50                          # maximum sentence length (stored in config)
              ):
        """Train a word2vec model on `train_data` and build the lookup tables."""
        self.N_DIM = N_DIM
        self.MIN_COUNT = MIN_COUNT
        self.w2v_EPOCH = w2v_EPOCH
        self.MAXLEN = MAXLEN

        # Initialize model and build vocab (gensim < 4.0 API: `size=`).
        imdb_w2v = Word2Vec(size=N_DIM, min_count=MIN_COUNT)
        imdb_w2v.build_vocab(train_data)

        # Train over the corpus (this may take several minutes).
        # corpus_count is filled in by build_vocab and, unlike
        # len(train_data), also works for non-sized iterables.
        imdb_w2v.train(train_data,
                       total_examples=imdb_w2v.corpus_count,
                       epochs=w2v_EPOCH)
        print("model train done")

        # word2vec post-processing: build index dict + embedding matrix.
        self._build_tables(imdb_w2v.wv, N_DIM)

    def _build_tables(self, wv, n_dim):
        """Build `word2idx_dic` / `embedding_weights` from trained vectors.

        `wv` must expose `.vocab` and item access (gensim KeyedVectors).
        Slot 0 keeps the zero vector for "<PAD>"; the last slot holds
        "<UNK>", initialised to the mean of all known word vectors.
        """
        words = list(wv.vocab.keys())
        n_symbols = len(words) + 2  # +2: "<PAD>" (index 0) and "<UNK>" (last index)
        embedding_weights = [[0.0] * n_dim for _ in range(n_symbols)]
        word2idx_dic = {"<PAD>": 0}
        for idx, word in enumerate(words, start=1):
            embedding_weights[idx] = wv[word]
            word2idx_dic[word] = idx

        # Slot reserved for out-of-vocabulary words: mean of all known
        # vectors (left as zeros when the vocabulary is empty).
        unk_idx = n_symbols - 1
        if words:
            embedding_weights[unk_idx] = np.mean(embedding_weights[1:unk_idx], axis=0).tolist()
        word2idx_dic["<UNK>"] = unk_idx

        self.word2idx_dic = word2idx_dic
        self.embedding_weights = embedding_weights

    def save(self,
             word2idx_dic_path,                   # path for the word -> index dict
             embedding_path,                      # path for the embedding matrix
             model_conf_path                      # path for the model configuration
             ):
        """Persist the lookup dict, the embedding matrix and the config."""
        # word -> index dictionary
        save_pkl(word2idx_dic_path, self.word2idx_dic)
        # embedding weight matrix
        save_pkl(embedding_path, self.embedding_weights)
        # hyper-parameter configuration
        save_pkl(model_conf_path, [self.N_DIM, self.MIN_COUNT, self.w2v_EPOCH, self.MAXLEN])

    def __load_default__(self):
        """Lazily load the default shopping-reviews artefacts."""
        self.load("./data/ptm/shopping_reviews/w2v_word2idx2020100601.pkl",
                  "./data/ptm/shopping_reviews/w2v_model_metric_2020100601.pkl", 
                  "./data/ptm/shopping_reviews/w2v_model_conf_2020100601.pkl")

    def load(self, word2idx_dic_path, embedding_path, model_conf_path):
        """Restore the dict, embedding matrix and config saved by `save`."""
        self.N_DIM, self.MIN_COUNT, self.w2v_EPOCH, self.MAXLEN = load_pkl(model_conf_path)
        self.embedding_weights = load_pkl(embedding_path)
        self.word2idx_dic = load_pkl(word2idx_dic_path)

    def word2idx(self, word):
        """Return the index of `word`, falling back to the "<UNK>" index."""
        if not self.word2idx_dic:
            self.__load_default__()
        # "<UNK>" always occupies the last slot, i.e. len(dict) - 1.
        return self.word2idx_dic.get(word, len(self.word2idx_dic) - 1)

    def sentence2idx(self, sentence, batch_len = -1):
        """Map a token sequence to indices; truncate/pad to `batch_len` if >= 0."""
        sentence_idx = [self.word2idx(token) for token in sentence]
        if batch_len >= 0:
            sentence_idx = sentence_idx[:batch_len]
            # 0 is the "<PAD>" index
            sentence_idx += [0] * (batch_len - len(sentence_idx))
        return sentence_idx

    def batch2idx(self, source_data, batch_len = -1):
        """Apply `sentence2idx` to every sentence in `source_data`."""
        return [self.sentence2idx(sentence, batch_len) for sentence in source_data]

    def get_np_weights(self):
        """Return the embedding matrix as a numpy array."""
        return np.array(self.embedding_weights)

    def load_keyword_type(self, path):
        """Load pre-trained vectors (word2vec text format) and build tables."""
        imdb_w2v = KeyedVectors.load_word2vec_format(path,
                                                     binary=False, encoding="utf8",  unicode_errors='ignore')

        self.N_DIM = imdb_w2v.vector_size
        # Training hyper-parameters are unknown for external vectors.
        self.MIN_COUNT = -1
        self.w2v_EPOCH = -1
        self.MAXLEN = -1

        # KeyedVectors itself exposes .vocab and item access, so no
        # deprecated `.wv` indirection is needed.
        self._build_tables(imdb_w2v, self.N_DIM)

if __name__ == '__main__':
    # One-off conversion: read the pre-trained weibo character vectors
    # and persist them as the project's pkl artefacts (word->index dict,
    # embedding matrix, hyper-parameter config).
    converter = Word2vector()
    converter.load_keyword_type("./data/ptm/weibo_w2v_300/sgns.weibo.char")
    converter.save(
        "./data/ptm/weibo_w2v_300/w2v_word2idx2020101701.pkl",
        "./data/ptm/weibo_w2v_300/w2v_model_metric_2020101701.pkl",
        "./data/ptm/weibo_w2v_300/w2v_model_conf_2020101701.pkl",
    )
    print("preprocess success")



